#pragma once

#include "functions.h"

/**
 * Network built from GSL objects: a reservoir matrix W, input weights W_in,
 * output weights W_out, feedback weights W_fb, and the state vectors x and y.
 *
 * Every GSL object is allocated in the constructor and released in the
 * destructor, so an instance is the sole owner of its buffers. For that reason
 * instances are non-copyable (see the private copy operations below).
 *
 * NOTE(review): the "esn" prefix suggests an Echo State Network (reservoir
 * computing). Training and run semantics live in the implementation file and
 * are not visible here — confirm there.
 */
class esn_network 
{
private:
	int nResUnits, nOutputs, nInputs; 
	double connectivity, resScaling;
	// Owned GSL matrices, all released in ~esn_network().
	// W_in stays NULL when the network is constructed with nInputs <= 0.
	gsl_matrix *W, *W_in, *W_out, *W_fb;
	// Owned, zero-initialised state vectors: x has nResUnits entries, y has nOutputs entries.
	gsl_vector *x, *y;
	// Stored from the constructor argument; note W_fb is allocated regardless of this flag.
	bool useFeedback;

	// Copying is disabled. The implicit member-wise copy would duplicate the raw
	// GSL pointers, and both objects would then free the same buffers in
	// ~esn_network() (double free). These are declared but intentionally never
	// defined (pre-C++11 equivalent of "= delete"), so any copy fails to compile
	// or, from inside the class, fails to link.
	esn_network(const esn_network &);
	esn_network &operator=(const esn_network &);
public:
	/**
	 * Allocates all matrices and state vectors of the network.
	 *
	 * @param nResUnits    Number of reservoir units (length of x; also passed to create_reservoir).
	 * @param nOutputs     Number of outputs (length of y).
	 * @param nInputs      Number of inputs; when <= 0 no input matrix is allocated and W_in is NULL.
	 * @param connectivity Passed to create_reservoir when building W.
	 * @param resScaling   Passed to create_reservoir when building W.
	 * @param inScaling    Passed to create_matrix when building W_in (unused if nInputs <= 0).
	 * @param outScaling   Passed to create_matrix when building W_out.
	 * @param fbScaling    Passed to create_matrix when building W_fb.
	 * @param useFeedback  Stored in the member of the same name.
	 */
	esn_network(int nResUnits, int nOutputs, int nInputs, double connectivity = 0.10, double resScaling = 0.95, double inScaling = 0.5, double outScaling = 0.5, double fbScaling = 0.5, bool useFeedback = false) : nResUnits(nResUnits), nOutputs(nOutputs), nInputs(nInputs), connectivity(connectivity), resScaling(resScaling), useFeedback(useFeedback)
	{
		x = gsl_vector_calloc(nResUnits);
		y = gsl_vector_calloc(nOutputs);
		W = create_reservoir(nResUnits, connectivity, resScaling);
		W_in = NULL;
		if(nInputs > 0){
			W_in = create_matrix(nResUnits, nInputs, inScaling);
		}
		// Presumably create_matrix(rows, cols, scaling); the former gsl_matrix_alloc
		// calls used the same (rows, cols) order — confirm in functions.h.
		W_out = create_matrix(nOutputs, nResUnits, outScaling);
		W_fb = create_matrix(nResUnits, nOutputs, fbScaling);
	}

	/// Releases every GSL object allocated by the constructor.
	~esn_network()
	{
		gsl_vector_free(x);
		gsl_vector_free(y);
		gsl_matrix_free(W);
		// Guard on the pointer itself rather than on nInputs, so the free always
		// matches whether an input matrix actually exists.
		if(W_in != NULL){
			gsl_matrix_free(W_in);
		}
		gsl_matrix_free(W_out);
		gsl_matrix_free(W_fb);
	}

	/// Builds a reservoir matrix from the given size, connectivity and scaling (defined elsewhere).
	gsl_matrix *create_reservoir(int nResUnits, double connectivity, double resScaling);
	/// Prints the reservoir (defined elsewhere; output format not visible here).
	void print_reservoir();

	// Matrix accessors (defined elsewhere).
	// NOTE(review): presumably these return the internal, still-owned pointers;
	// callers must not free them — confirm in the implementation.
	gsl_matrix *get_matrix_W();
	gsl_matrix *get_matrix_Win();
	gsl_matrix *get_matrix_Wout();
	gsl_matrix *get_matrix_Wfb();

	/// Applies the activation function to v (presumably in place, given the void return — confirm).
	void activationFunction(gsl_vector *v);
	/// Applies the inverse activation function to v (presumably in place — confirm).
	void activationFunction_inv(gsl_vector *v);

	/**
	 * Offline training from input/desired matrices (defined elsewhere).
	 * @param stateCollectingOffset Defaults to 0. NOTE(review): presumably the
	 *                              number of initial steps skipped before states
	 *                              are collected — confirm.
	 */
	void train_offline(gsl_matrix *input, gsl_matrix *desired, unsigned int stateCollectingOffset = 0);
	/// Offline training overload taking a state-collection stride and start step (semantics defined elsewhere).
	void train_offline(gsl_matrix *input, gsl_matrix *desired, unsigned int stateCollectEach, unsigned int stateCollectStartAt);
	
	/// Runs the network for the given number of steps and returns a matrix (ownership not visible here — confirm).
	gsl_matrix *runNetwork(gsl_matrix *input, gsl_matrix *desired, int steps, int forcing_steps);
	
	/// Persists the network (defined elsewhere; target and format not visible here).
	void saveNetwork();
	/// Restores the network (defined elsewhere; source and format not visible here).
	void loadNetwork();
};
